import argparse
import logging
import sys
import os
from configparser import ConfigParser

from torch import optim
import torch

from disvae import init_specific_model, Trainer, Evaluator
from disvae.models.anneal import get_anneal
from disvae.utils.modelIO import save_model, load_model, load_metadata, save_metadata
from disvae.models.losses import LOSSES, RECON_DIST, get_loss_f
from disvae.models.vae import MODELS

from utils.datasets import get_dataloaders, DATASETS
from utils.helpers import (create_safe_directory, get_device, set_seed, get_n_param,
                           get_config_section, update_namespace_, FormatterNoDuplicate)
from utils.visualize import GifTraversalsTraining

CONFIG_FILE = "hyperparam.ini"
# All standard logging level names, most to least severe. Built from the
# public level constants instead of the private ``logging._levelToName``
# mapping, which is an implementation detail and may change between versions.
LOG_LEVELS = [logging.getLevelName(level)
              for level in (logging.CRITICAL, logging.ERROR, logging.WARNING,
                            logging.INFO, logging.DEBUG, logging.NOTSET)]


def parse_arguments(args_to_parse):
    """Parse the command line arguments.

    Defaults are read from the ``Custom`` section of ``hyperparam.ini``;
    after parsing, the namespace is additionally updated with the
    loss-specific ``Common_<loss>`` section of the same file.

    Parameters
    ----------
    args_to_parse: list of str
        Arguments to parse (split on whitespace), e.g. ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        The fully resolved configuration.
    """
    def str2bool(value):
        # ``type=bool`` is a well-known argparse pitfall: any non-empty
        # string (including "False") is truthy. Parse the usual spellings
        # of both truth values explicitly instead.
        if isinstance(value, bool):
            return value
        if value.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        if value.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected, got {!r}.".format(value))

    default_config = get_config_section([CONFIG_FILE], "Custom")

    description = "PyTorch implementation and evaluation of disentangled Variational AutoEncoders and metrics."
    parser = argparse.ArgumentParser(description=description,
                                     formatter_class=FormatterNoDuplicate)

    # General options
    general = parser.add_argument_group('General options')
    general.add_argument('--name', type=str, default=None,
                         help="Name of the model for storing and loading purposes.")
    general.add_argument('-L', '--log_level', help="Logging levels.",
                         default=default_config['log_level'], choices=LOG_LEVELS)
    general.add_argument('--no_progress_bar', action='store_true',
                         default=default_config['no_progress_bar'],
                         help='Disables progress bar.')
    general.add_argument('--device',
                         default=default_config['device'],
                         help='cuda:n,or cpu')
    general.add_argument('-s', '--seed', type=int, default=default_config['seed'],
                         help='Random seed. Can be `None` for stochastic behavior.')
    general.add_argument('--monitor', action='store_true',
                         help='monitor the training process')

    general.add_argument('-g', '--group', type=int, default=5)
    general.add_argument('--simultaneous', '-ss', action='store_true',
                         help='train encoders simultaneously')
    general.add_argument('--base', type=str, default="",
                         help='Choose base model for F-VAE. A list of pressures '
                              'eg. 100,40, meaning first phase has pressure 100 then 40.')

    general.add_argument('--beta', type=float, default=1,
                         help='Strength of pressure.')

    general.add_argument('--lr_decay', type=float, default=0.05,
                         help='Learning-rate decay factor applied to encoders '
                              'trained in previous phases.')
    # Learning options
    training = parser.add_argument_group('Training specific options')
    training.add_argument('--checkpoint_every',
                          type=int, default=default_config['checkpoint_every'],
                          help='Save a checkpoint of the trained model every n epoch.')
    training.add_argument('-d', '--dataset', help="Path to training data.",
                          default=default_config['dataset'])

    training.add_argument('-e', '--epochs', type=int,
                          default=default_config['epochs'],
                          help='Maximum number of epochs to run for.')
    training.add_argument('-b', '--batch_size', type=int,
                          default=default_config['batch_size'],
                          help='Batch size for training.')
    training.add_argument('--lr', type=float, default=default_config['lr'],
                          help='Learning rate.')

    # Model Options
    model = parser.add_argument_group('Model specfic options')
    model.add_argument('-m', '--model_type',
                       default=default_config['model'], choices=MODELS,
                       help='Type of encoder and decoder to use.')
    model.add_argument('-z', '--latent_dim', type=int,
                       default=default_config['latent_dim'],
                       help='Dimension of the latent variable.')
    model.add_argument('-l', '--loss',
                       default=default_config['loss'], choices=LOSSES,
                       help="Type of VAE loss function to use.")
    model.add_argument('-r', '--rec_dist', default=default_config['rec_dist'],
                       choices=RECON_DIST,
                       help="Form of the likelihood ot use for each pixel.")

    # Loss Specific Options
    betaH = parser.add_argument_group('BetaH specific parameters')

    betaH.add_argument('--betaH_B', type=float,
                       default=default_config['betaH_B'],
                       help="Weight of the KL (beta in the paper).")

    betaB = parser.add_argument_group('BetaB specific parameters')
    betaB.add_argument('--betaB_initC', type=float,
                       default=default_config['betaB_initC'],
                       help="Starting annealed capacity.")
    betaB.add_argument('--betaB_finC', type=float,
                       default=default_config['betaB_finC'],
                       help="Final annealed capacity.")
    betaB.add_argument('--betaB_G', type=float,
                       default=default_config['betaB_G'],
                       help="Weight of the KL divergence term (gamma in the paper).")

    factor = parser.add_argument_group('factor VAE specific parameters')
    factor.add_argument('--factor_G', type=float,
                        default=default_config['factor_G'],
                        help="Weight of the TC term (gamma in the paper).")
    factor.add_argument('--lr_disc', type=float,
                        default=default_config['lr_disc'],
                        help='Learning rate of the discriminator.')

    btcvae = parser.add_argument_group('beta-tcvae specific parameters')
    btcvae.add_argument('--btcvae_A', type=float,
                        default=default_config['btcvae_A'],
                        help="Weight of the MI term (alpha in the paper).")
    btcvae.add_argument('--btcvae_G', type=float,
                        default=default_config['btcvae_G'],
                        help="Weight of the dim-wise KL term (gamma in the paper).")
    btcvae.add_argument('--btcvae_B', type=float,
                        default=default_config['btcvae_B'],
                        help="Weight of the TC term (beta in the paper).")

    # Learning options
    evaluation = parser.add_argument_group('Evaluation specific options')
    # NOTE: these three flags previously used ``type=bool``, which treats ANY
    # non-empty string (even "False") as True; ``str2bool`` parses them properly.
    evaluation.add_argument('--is_eval_only', type=str2bool,
                            default=default_config['is_eval_only'],
                            help='Whether to only evaluate using precomputed model `name`.')
    evaluation.add_argument('--is_metrics', type=str2bool,
                            default=default_config['is_metrics'],
                            help="Whether to compute the disentangled metrcics. Currently only possible with `dsprites` as it is the only dataset with known true factors of variations.")
    evaluation.add_argument('--no_test', type=str2bool,
                            default=default_config['no_test'],
                            help="Whether not to compute the test losses.`")
    evaluation.add_argument('--eval_batchsize', type=int,
                            default=default_config['eval_batchsize'],
                            help='Batch size for evaluation.')

    # Anneal options
    anneal_opt = parser.add_argument_group('Anneal specific options')
    anneal_opt.add_argument('--anneal_name', default='constant',
                            choices=['constant', 'monotonic', 'cyclic', 'cosine'])
    anneal_opt.add_argument('--anneal_l', default=1, type=float)
    anneal_opt.add_argument('--anneal_r', default=1, type=float)
    anneal_opt.add_argument('--anneal_count', default=1, type=float)

    args = parser.parse_args(args_to_parse)

    # Overlay the loss-specific config section (keyed by loss name, not dataset).
    loss_name = args.loss
    common_model = get_config_section([CONFIG_FILE], "Common_{}".format(loss_name))
    update_namespace_(args, common_model)

    return args


def main(args):
    """Main train and evaluation function.

    If ``args.is_eval_only`` is False, trains a model (optionally
    warm-starting from a previous multi-step checkpoint named by
    ``args.base``) and saves it; otherwise loads the model ``args.name``.
    In both cases, evaluation runs afterwards unless disabled by
    ``args.no_test`` / ``args.is_metrics``.

    Parameters
    ----------
    args: argparse.Namespace
        Parsed command-line arguments (see ``parse_arguments``).
    """
    import wandb  # imported lazily so the module can be imported without wandb installed

    # Console logger for progress reporting.
    formatter = logging.Formatter('%(asctime)s %(levelname)s - %(funcName)s: %(message)s',
                                  "%H:%M:%S")
    logger = logging.getLogger(__name__)
    logger.setLevel(args.log_level.upper())
    stream = logging.StreamHandler()
    stream.setLevel(args.log_level.upper())
    stream.setFormatter(formatter)
    logger.addHandler(stream)

    set_seed(args.seed)
    device = get_device(args.device)

    group = args.group
    if not args.is_eval_only:
        if args.loss == "factor":
            logger.info(
                "FactorVae needs 2 batches per iteration. To replicate this behavior while being consistent, we double the batch size and the the number of epochs.")
            args.batch_size *= 2
            args.epochs *= 2

        # PREPARES DATA
        train_loader = get_dataloaders(args.dataset,
                                       num_workers=4,
                                       batch_size=args.batch_size)
        logger.info("Train {} with {} samples".format(args.dataset, len(train_loader.dataset)))

        # PREPARES MODEL
        # Dataset reports (H, W, C); the model expects (C, H, W) — TODO confirm
        # against the dataset class.
        observation_shape = train_loader.dataset.observation_shape
        args.img_size = [observation_shape[2], observation_shape[0],
                         observation_shape[1]]
        model = init_specific_model(args.model_type, args.img_size, args.latent_dim, group)
        logger.info('Num parameters in model: {}'.format(get_n_param(model)))

        # TRAINS
        # The decoder always trains at the full learning rate.
        optimizer = optim.Adam(model.decoder.parameters(), lr=args.lr, weight_decay=0)
        if args.simultaneous:
            # Train all encoders together at the full learning rate.
            optimizer.add_param_group({'params': model.encoders.parameters(),
                                       'lr': args.lr,
                                       'weight_decay': 0})
        else:
            # Multi-step training: encoders from earlier phases are restored
            # from `multi_step/<base>.pt`; `n_trained` counts them and is also
            # the index of the encoder trained in this phase.
            if args.base == '':
                n_trained = 0
            else:
                n_trained = len(args.base.split(','))
                model.load_state_dict(torch.load(os.path.join('multi_step', args.base + '.pt')))
            logger.info('Encoder index trained at full lr: {}'.format(n_trained))
            # Give each encoder its own learning rate: the current phase's
            # encoder gets the full lr, previously trained ones a decayed lr.
            for encoder_idx, encoder in enumerate(model.encoders):
                if encoder_idx == n_trained:
                    optimizer.add_param_group({'params': encoder.parameters(),
                                               'lr': args.lr,
                                               'weight_decay': 0})
                else:
                    optimizer.add_param_group({'params': encoder.parameters(),
                                               'lr': args.lr * args.lr_decay,
                                               'weight_decay': 0})

        # Anneal schedule for the KL pressure; endpoints scaled by beta.
        b = args.beta
        anneal = get_anneal(args.anneal_name, 200000, args.anneal_l * b, args.anneal_r * b)

        model = model.to(device)  # make sure trainer and viz on same device

        # A named run saves locally; an unnamed run logs to wandb and saves
        # into the wandb run directory.
        if args.name:
            exp_dir = args.name
        else:
            wandb.init(config=args, project='fractionVAE')
            exp_dir = wandb.run.dir
        logger.info("Root directory for saving and loading experiments: {}".format(exp_dir))

        gif_visualizer = GifTraversalsTraining(model, args.dataset, exp_dir)
        loss_f = get_loss_f(args.loss, anneal,
                            n_data=len(train_loader.dataset),
                            **vars(args))

        trainer = Trainer(model, optimizer, loss_f,
                          device=device,
                          logger=logger,
                          save_dir=exp_dir,
                          is_progress_bar=not args.no_progress_bar,
                          gif_visualizer=gif_visualizer,)

        # Save experiment information up front so it survives a crashed run.
        save_metadata(vars(args), exp_dir)

        trainer(train_loader,
                epochs=args.epochs,
                checkpoint_every=args.checkpoint_every, )
        gif_visualizer.save_reset()

        # Checkpoint for the next multi-step phase, named by the comma-joined
        # pressure history (e.g. "100,40").
        if args.base == '':
            n_base = str(int(args.beta))
        else:
            n_base = args.base + ',' + str(int(args.beta))
        if not args.simultaneous:
            save_model(trainer.model, 'multi_step', filename=f'{n_base}.pt')

        # SAVE MODEL AND EXPERIMENT INFORMATION
        save_model(trainer.model, exp_dir, metadata=vars(args))
    else:
        exp_dir = args.name
        model = load_model(exp_dir, device)

    if args.is_metrics or not args.no_test:
        metadata = load_metadata(exp_dir)
        # TO-DO: currently uses train datatset
        test_loader = get_dataloaders(metadata["dataset"],
                                      batch_size=args.eval_batchsize,
                                      num_workers=4,
                                      shuffle=False)
        # Evaluation uses a constant (no-op) anneal with unit endpoints.
        perid = int(len(test_loader) * args.epochs / args.anneal_count)
        anneal = get_anneal('constant', perid, 1, 1)

        loss_f = get_loss_f('btcvae', anneal,
                            n_data=len(test_loader.dataset),
                            **vars(args))

        evaluator = Evaluator(model, loss_f,
                              device=device,
                              logger=logger,
                              save_dir=exp_dir,
                              is_progress_bar=not args.no_progress_bar)

        evaluator(test_loader, is_metrics=args.is_metrics, is_losses=not args.no_test)
        # Named and eval-only runs never call wandb.init, so only finish an
        # actually-active run.
        if wandb.run is not None:
            wandb.join()


if __name__ == '__main__':
    # Parse CLI arguments and run training/evaluation.
    main(parse_arguments(sys.argv[1:]))
